import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.sam.sam_layers import (
    RandomFrequencyPositionalEmbeddings,
)


@keras_hub_export("keras_hub.layers.SAMPromptEncoder")
class SAMPromptEncoder(keras.layers.Layer):
    """Prompt Encoder for the Segment Anything Model (SAM).

    The prompt encoder generates encodings for three types of prompts:
    - Point prompts: Points on the image along with a label indicating whether
        the point is in the foreground (part of the mask) or in the background
        (not a part of the mask).
    - Box prompts: A batch of bounding boxes with format [(x1, y1), (x2, y2)]
        used to determine the location of the masks in the image.
    - Masks: An input mask can be passed to produce the dense prompt
        embeddings used to refine the output mask.

    Positional encodings are generated for the point prompts and the box
    prompts using random spatial frequencies, and the results are
    concatenated. A point is represented as the sum of a positional encoding
    of the point's location and one of two learned embeddings that indicate
    if the point is either in the foreground or background. A box is
    represented by an embedding pair:
    (1) the positional encoding of its top-left corner summed with a learned
    embedding representing "top-left corner" and
    (2) the same structure but using a learned embedding indicating
    "bottom-right corner".
    The box and point encodings are referred to as "sparse encodings" and are
    returned under the `"prompt_sparse_embeddings"` key.
    If a mask prompt is passed, a convolutional neural net is used to
    downscale it to generate "dense encodings". If no mask prompt is passed,
    an embedding layer is used instead to generate a "no mask" embedding.

    Args:
        hidden_size: int, optional. The number of features in the output
            embeddings. Defaults to `256`.
        image_embedding_size: tuple[int], optional. The spatial size
            `(height, width)` of the image embeddings generated by an image
            encoder. Defaults to `(64, 64)`.
        input_image_size: tuple[int], optional. A tuple of the height and
            width of the image being prompted. Defaults to `(1024, 1024)`.
        mask_in_channels: int, optional. The number of intermediate channels
            used by the mask downscaler (its first convolution uses
            `mask_in_channels // 4` filters). Defaults to `16`.
        activation: str, optional. The activation to use in the mask
            downscaler neural net. Defaults to `"gelu"`.
    """

    def __init__(
        self,
        *,
        hidden_size=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_channels=16,
        activation="gelu",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.image_embedding_size = image_embedding_size
        self.input_image_size = input_image_size
        self.mask_in_channels = mask_in_channels
        self.activation = activation

        # Each positional feature yields two output features, so request
        # half of `hidden_size`.
        self.positional_embedding_layer = RandomFrequencyPositionalEmbeddings(
            num_positional_features=self.hidden_size // 2, scale=1
        )

        # Single-entry embedding tables used as learned additive vectors.
        self.foreground_point_embed = keras.layers.Embedding(
            1, hidden_size, name="foreground_point_embed"
        )
        self.background_point_embed = keras.layers.Embedding(
            1, hidden_size, name="background_point_embed"
        )
        self.top_left_corner_embed = keras.layers.Embedding(
            1, hidden_size, name="top_left_corner_embed"
        )
        self.bottom_right_corner_embed = keras.layers.Embedding(
            1, hidden_size, name="bottom_right_corner_embed"
        )
        self.not_a_point_embed = keras.layers.Embedding(
            1, hidden_size, name="not_a_point_embed"
        )

        # Two stride-2 convolutions reduce the mask's spatial size by 4x, then
        # a 1x1 convolution projects to `hidden_size` channels.
        self.mask_downscaler = keras.models.Sequential(
            [
                keras.layers.Conv2D(
                    mask_in_channels // 4, kernel_size=2, strides=2
                ),
                keras.layers.LayerNormalization(epsilon=1e-6),
                keras.layers.Activation(activation),
                keras.layers.Conv2D(mask_in_channels, kernel_size=2, strides=2),
                keras.layers.LayerNormalization(epsilon=1e-6),
                keras.layers.Activation(activation),
                keras.layers.Conv2D(hidden_size, kernel_size=1),
            ],
            name="mask_downscaler",
        )
        self.no_mask_embed = keras.layers.Embedding(
            1, hidden_size, name="no_mask_embed"
        )

    def build(
        self,
        points_shape=None,
        labels_shape=None,
        boxes_shape=None,
        masks_shape=None,
    ):
        self.positional_embedding_layer.build()
        for layer in [
            self.foreground_point_embed,
            self.background_point_embed,
            self.top_left_corner_embed,
            self.bottom_right_corner_embed,
            self.not_a_point_embed,
            self.no_mask_embed,
        ]:
            layer.build([None])
        # The downscaler consumes single-channel masks that are 4x the image
        # embedding size, so its output matches `image_embedding_size`.
        self.mask_downscaler.build(
            [
                None,
                4 * self.image_embedding_size[0],
                4 * self.image_embedding_size[1],
                1,
            ]
        )
        self.built = True

    def compute_output_shape(
        self,
        points_shape=None,
        labels_shape=None,
        boxes_shape=None,
        masks_shape=None,
    ):
        # Batch dimensions must agree, so take it from the first known shape.
        batch_size = None
        for shape in (points_shape, labels_shape, boxes_shape, masks_shape):
            if shape is not None:
                batch_size = shape[0]
                break
        # NOTE(review): `call` returns `prompt_dense_positional_embeddings`
        # with a leading dimension of 1 (it is not broadcast over the batch),
        # while this reports `batch_size`. Downstream consumers presumably
        # broadcast it; confirm before relying on this symbolic shape.
        return {
            "prompt_sparse_embeddings": (
                batch_size,
                None,
                self.hidden_size,
            ),
            "prompt_dense_embeddings": (
                batch_size,
                self.image_embedding_size[0],
                self.image_embedding_size[1],
                self.hidden_size,
            ),
            "prompt_dense_positional_embeddings": (
                batch_size,
                self.image_embedding_size[0],
                self.image_embedding_size[1],
                self.hidden_size,
            ),
        }

    def _embed_points(self, points, labels):
        """Embed point prompts according to their labels.

        `points` is assumed to be `(batch, num_points, 2)` and `labels`
        `(batch, num_points)`. A label of `0` adds the background embedding,
        `-1` replaces the encoding with the "not a point" embedding, and any
        other value adds the foreground embedding.
        """
        # Offset by half a pixel; presumably to address pixel centres as in
        # the reference SAM implementation.
        points = points + 0.5
        indices = ops.arange(1, dtype="int32")

        point_embeddings = self.positional_embedding_layer.encode_coordinates(
            points, self.input_image_size
        )
        labels = ops.broadcast_to(
            labels[..., None], ops.shape(point_embeddings)
        )
        point_embeddings = ops.where(
            labels == 0,
            point_embeddings + self.background_point_embed(indices),
            point_embeddings + self.foreground_point_embed(indices),
        )
        # Padding points carry no positional information at all.
        point_embeddings = ops.where(
            labels == -1,
            self.not_a_point_embed(indices),
            point_embeddings,
        )
        return point_embeddings

    def _embed_box(self, box):
        """Embed box prompts as pairs of corner embeddings.

        `box` is indexed as `(batch, num_boxes, 2, 2)`, the third axis being
        (top-left, bottom-right). Returns a `(batch, 2 * num_boxes,
        hidden_size)` tensor with each box's two corners adjacent.
        """
        shape = ops.shape(box)
        batch_size, num_boxes = shape[0], shape[1]
        box = box + 0.5
        indices = ops.arange(1, dtype="int32")
        corner_embedding = self.positional_embedding_layer.encode_coordinates(
            box, self.input_image_size
        )
        top_left_embedding = corner_embedding[
            :, :, 0, :
        ] + self.top_left_corner_embed(indices)
        bottom_right_embedding = corner_embedding[
            :, :, 1, :
        ] + self.bottom_right_corner_embed(indices)
        corner_embedding = ops.stack(
            [top_left_embedding, bottom_right_embedding], axis=2
        )
        return ops.reshape(
            corner_embedding, (batch_size, num_boxes * 2, self.hidden_size)
        )

    def _embed_mask(self, mask):
        """Downscale a `(batch, height, width, 1)` mask to dense embeddings."""
        return self.mask_downscaler(mask)

    def _broadcast_no_mask_embed(self, batch_size):
        """Tile the learned "no mask" vector over the dense embedding grid."""
        reshaped_embed = ops.reshape(
            self.no_mask_embed(ops.arange(1, dtype="int32")),
            (1, 1, 1, self.hidden_size),
        )
        return ops.broadcast_to(
            reshaped_embed,
            shape=(
                batch_size,
                self.image_embedding_size[0],
                self.image_embedding_size[1],
                self.hidden_size,
            ),
        )

    def call(
        self, images=None, points=None, labels=None, boxes=None, masks=None
    ):
        """Encode the given prompts.

        Args:
            images: Unused; accepted for interface compatibility.
            points: Optional point coordinates, `(batch, num_points, 2)`.
            labels: Optional point labels, `(batch, num_points)`. `0` marks
                a background point, `-1` marks padding ("not a point") and
                any other value marks a foreground point.
            boxes: Optional boxes, `(batch, num_boxes, 2, 2)`.
            masks: Optional mask prompts, `(batch, num_masks, height, width,
                1)`, where `height`/`width` are expected to be 4x
                `image_embedding_size`. Masks are flattened over
                `batch * num_masks` before downscaling.

        Returns:
            A dict with `"prompt_sparse_embeddings"` (point and box
            embeddings concatenated along axis 1),
            `"prompt_dense_embeddings"` (downscaled masks, or the broadcast
            "no mask" embedding) and `"prompt_dense_positional_embeddings"`
            (positional encoding of the image embedding grid, with a leading
            dimension of 1).

        Raises:
            ValueError: If `points`, `labels`, `boxes` and `masks` are all
                `None`, since the batch size cannot be inferred.
        """
        valid_inputs = [
            x for x in (points, labels, boxes, masks) if x is not None
        ]
        if not valid_inputs:
            raise ValueError(
                "`SAMPromptEncoder` requires at least one of `points`, "
                "`labels`, `boxes` or `masks`. Received all `None`."
            )
        # Batch shapes must all match, so any provided input will do.
        batch_size = ops.shape(valid_inputs[0])[0]
        # Missing prompts become zero-length placeholders so the rest of the
        # computation is shape-uniform.
        if points is None:
            points = ops.zeros((batch_size, 0, 2))
        if labels is None:
            labels = ops.zeros((batch_size, 0))
        if boxes is None:
            boxes = ops.zeros((batch_size, 0, 2, 2))
        if masks is None:
            # Spatial size matches what `mask_downscaler` is built for.
            masks = ops.zeros(
                (
                    batch_size,
                    0,
                    4 * self.image_embedding_size[0],
                    4 * self.image_embedding_size[1],
                    1,
                )
            )

        point_embeddings = self._embed_points(points, labels)
        box_embeddings = self._embed_box(boxes)
        sparse_embeddings = ops.concatenate(
            [point_embeddings, box_embeddings], axis=1
        )

        def _no_mask_embed():
            return self._broadcast_no_mask_embed(batch_size)

        def _maybe_input_mask_embed():
            # Keras passes the masks as concrete tensors to both branches of
            # `ops.cond` to build the output shape, so this branch must also
            # handle a statically 0-sized `masks` and fall back to the
            # "no mask" embedding.
            if masks.shape[1] == 0:
                return self._broadcast_no_mask_embed(batch_size)
            shape = ops.shape(masks)
            flat_masks = ops.reshape(
                masks,
                (shape[0] * shape[1], shape[2], shape[3], shape[4]),
            )
            return self._embed_mask(flat_masks)

        dense_embeddings = ops.cond(
            ops.equal(ops.size(masks), 0),
            _no_mask_embed,
            _maybe_input_mask_embed,
        )

        # Not broadcast over the batch; the leading dimension is 1.
        dense_positional_embeddings = (
            self.positional_embedding_layer.encode_image(
                self.image_embedding_size
            )[None, ...]
        )

        return {
            "prompt_sparse_embeddings": sparse_embeddings,
            "prompt_dense_embeddings": dense_embeddings,
            "prompt_dense_positional_embeddings": dense_positional_embeddings,
        }

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "image_embedding_size": self.image_embedding_size,
                "input_image_size": self.input_image_size,
                "mask_in_channels": self.mask_in_channels,
                "activation": self.activation,
            }
        )
        return config
